/* ----------------------------------------------------------------------
   LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
   https://www.lammps.org/, Sandia National Laboratories
   LAMMPS development team: developers@lammps.org

   Copyright (2003) Sandia Corporation.  Under the terms of Contract
   DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
   certain rights in this software.  This software is distributed under
   the GNU General Public License.

   See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */

/* ----------------------------------------------------------------------
   Contributing author: Anders Hafreager (UiO)
------------------------------------------------------------------------- */

#include "pair_vashishta_gpu.h"

#include "atom.h"
#include "comm.h"
#include "domain.h"
#include "error.h"
#include "gpu_extra.h"
#include "info.h"
#include "memory.h"
#include "neigh_list.h"
#include "neighbor.h"
#include "suffix.h"

#include <vector>

using namespace LAMMPS_NS;

// External functions from the GPU library (CUDA/OpenCL/HIP backends) for atom decomposition.
// They are only declared here; their definitions live in the separately built
// GPU library, so these prototypes must match those definitions exactly.

// Set up the accelerator-side model.  The coefficient arrays (cutsq ... bigc)
// are flat arrays with one entry per parameter set (length nparams), and
// host_map / host_elem3param receive the pair style's map and elem3param
// tables.  gpu_mode is passed by reference and may be updated by the library.
// The returned status is meant for GPU_EXTRA::check_flag().
int vashishta_gpu_init(const int ntypes, const int inum, const int nall, const int max_nbors,
                       const double cell_size, int &gpu_mode, FILE *screen, int *host_map,
                       const int nelements, int ***host_elem3param, const int nparams,
                       const double *cutsq, const double *r0, const double *gamma,
                       const double *eta, const double *lam1inv, const double *lam4inv,
                       const double *zizj, const double *mbigd, const double *dvrc,
                       const double *big6w, const double *heta, const double *bigh,
                       const double *bigw, const double *c0, const double *costheta,
                       const double *bigb, const double *big2b, const double *bigc);
// Release the library's state (called from the pair style destructor).
void vashishta_gpu_clear();
// Used when gpu_mode != GPU_FORCE: presumably builds neighbor data inside the
// library; returns the first-neighbor table and hands back ilist / jnum through
// the pointer arguments.  The caller treats success == false as insufficient
// accelerator memory.
int **vashishta_gpu_compute_n(const int ago, const int inum, const int nall, double **host_x,
                              int *host_type, double *sublo, double *subhi, tagint *tag,
                              int **nspecial, tagint **special, const bool eflag, const bool vflag,
                              const bool eatom, const bool vatom, int &host_start, int **ilist,
                              int **jnum, const double cpu_time, bool &success);
// Used when gpu_mode == GPU_FORCE: consumes the host-built neighbor list
// (ilist / numj / firstneigh); ln is the number of list entries including
// ghosts (inum + gnum at the call site).
void vashishta_gpu_compute(const int ago, const int nloc, const int nall, const int ln,
                           double **host_x, int *host_type, int *ilist, int *numj, int **firstneigh,
                           const bool eflag, const bool vflag, const bool eatom, const bool vatom,
                           int &host_start, const double cpu_time, bool &success);
// Presumably reports the library's memory usage in bytes (not called in this
// part of the file).
double vashishta_gpu_bytes();

/* ---------------------------------------------------------------------- */

PairVashishtaGPU::PairVashishtaGPU(LAMMPS *lmp) : PairVashishta(lmp), gpu_mode(GPU_FORCE)
{
  // state owned by the GPU variant itself
  gpu_allocated = false;
  cpu_time = 0.0;

  // generic pair-style settings
  reinitflag = 0;
  suffix_flag |= Suffix::GPU;

  // ghost atoms need their own neighbor entries; cutghost is created in allocate()
  ghostneigh = 1;
  cutghost = nullptr;

  // check that the GPU package is set up (problems are reported via lmp->error)
  GPU_EXTRA::gpu_ready(lmp->modify, lmp->error);
}

/* ----------------------------------------------------------------------
   check if allocated, since class can be destructed when incomplete
------------------------------------------------------------------------- */

PairVashishtaGPU::~PairVashishtaGPU()
{
  // always tear down the GPU library state first
  vashishta_gpu_clear();

  // the object may be destroyed before allocate() ever ran
  if (!allocated) return;
  memory->destroy(cutghost);
}

/* ---------------------------------------------------------------------- */

void PairVashishtaGPU::compute(int eflag, int vflag)
{
  ev_init(eflag, vflag);

  // owned plus ghost atoms on this rank
  const int nall = atom->nlocal + atom->nghost;

  int inum = 0;
  int host_start = 0;
  bool success = true;

  int *ilist = nullptr;
  int *numneigh = nullptr;
  int **firstneigh = nullptr;

  if (gpu_mode == GPU_FORCE) {
    // reuse the neighbor list LAMMPS built on the host (local + ghost entries)
    inum = list->inum;
    ilist = list->ilist;
    numneigh = list->numneigh;
    firstneigh = list->firstneigh;

    const int ntotal = inum + list->gnum;
    vashishta_gpu_compute(neighbor->ago, inum, nall, ntotal, atom->x, atom->type, ilist, numneigh,
                          firstneigh, eflag, vflag, eflag_atom, vflag_atom, host_start, cpu_time,
                          success);
  } else {
    // neighbor data comes from the GPU library, which needs this rank's sub-domain bounds
    double sublo[3], subhi[3];
    if (domain->triclinic) {
      // bounding box of the triclinic sub-domain
      domain->bbox(domain->sublo_lamda, domain->subhi_lamda, sublo, subhi);
    } else {
      for (int d = 0; d < 3; d++) {
        sublo[d] = domain->sublo[d];
        subhi[d] = domain->subhi[d];
      }
    }

    inum = atom->nlocal;
    firstneigh = vashishta_gpu_compute_n(neighbor->ago, inum, nall, atom->x, atom->type, sublo,
                                         subhi, atom->tag, atom->nspecial, atom->special, eflag,
                                         vflag, eflag_atom, vflag_atom, host_start, &ilist,
                                         &numneigh, cpu_time, success);
  }

  if (!success) error->one(FLERR, "Insufficient memory on accelerator");

  // molecular systems: rebuild topology lists on reneighboring steps
  if ((neighbor->ago == 0) && (atom->molecular != Atom::ATOMIC)) neighbor->build_topology();
}

/* ---------------------------------------------------------------------- */

void PairVashishtaGPU::allocate()
{
  // base-class per-type arrays are only created once
  if (!allocated) PairVashishta::allocate();

  // cutghost is indexed 1..ntypes in both dimensions
  const int np1 = atom->ntypes + 1;
  memory->create(cutghost, np1, np1, "pair:cutghost");

  gpu_allocated = true;
}

/* ----------------------------------------------------------------------
   init specific to this pair style
------------------------------------------------------------------------- */

void PairVashishtaGPU::init_style()
{
  double cell_size = cutmax + neighbor->skin;

  if (atom->tag_enable == 0) error->all(FLERR, "Pair style vashishta/gpu requires atom IDs");

  double *cutsq, *r0, *gamma, *eta;
  double *lam1inv, *lam4inv, *zizj, *mbigd;
  double *dvrc, *big6w, *heta, *bigh;
  double *bigw, *c0, *costheta, *bigb;
  double *big2b, *bigc;

  cutsq = r0 = gamma = eta = nullptr;
  lam1inv = lam4inv = zizj = mbigd = nullptr;
  dvrc = big6w = heta = bigh = nullptr;
  bigw = c0 = costheta = bigb = nullptr;
  big2b = bigc = nullptr;

  memory->create(cutsq, nparams, "pair:cutsq");
  memory->create(r0, nparams, "pair:r0");
  memory->create(gamma, nparams, "pair:gamma");
  memory->create(eta, nparams, "pair:eta");
  memory->create(lam1inv, nparams, "pair:lam1inv");
  memory->create(lam4inv, nparams, "pair:lam4inv");
  memory->create(zizj, nparams, "pair:zizj");
  memory->create(mbigd, nparams, "pair:mbigd");
  memory->create(dvrc, nparams, "pair:dvrc");
  memory->create(big6w, nparams, "pair:big6w");
  memory->create(heta, nparams, "pair:heta");
  memory->create(bigh, nparams, "pair:bigh");
  memory->create(bigw, nparams, "pair:bigw");
  memory->create(c0, nparams, "pair:c0");
  memory->create(costheta, nparams, "pair:costheta");
  memory->create(bigb, nparams, "pair:bigb");
  memory->create(big2b, nparams, "pair:big2b");
  memory->create(bigc, nparams, "pair:bigc");

  for (int i = 0; i < nparams; i++) {
    cutsq[i] = params[i].cutsq;
    r0[i] = params[i].r0;
    gamma[i] = params[i].gamma;
    eta[i] = params[i].eta;
    lam1inv[i] = params[i].lam1inv;
    lam4inv[i] = params[i].lam4inv;
    zizj[i] = params[i].zizj;
    mbigd[i] = params[i].mbigd;
    dvrc[i] = params[i].dvrc;
    big6w[i] = params[i].big6w;
    heta[i] = params[i].heta;
    bigh[i] = params[i].bigh;
    bigw[i] = params[i].bigw;
    c0[i] = params[i].c0;
    costheta[i] = params[i].costheta;
    bigb[i] = params[i].bigb;
    big2b[i] = params[i].big2b;
    bigc[i] = params[i].bigc;
  }
  int mnf = 5e-2 * neighbor->oneatom;
  int success = vashishta_gpu_init(atom->ntypes + 1, atom->nlocal, atom->nlocal + atom->nghost, mnf,
                                   cell_size, gpu_mode, screen, map, nelements, elem3param, nparams,
                                   cutsq, r0, gamma, eta, lam1inv, lam4inv, zizj, mbigd, dvrc,
                                   big6w, heta, bigh, bigw, c0, costheta, bigb, big2b, bigc);
  memory->destroy(cutsq);
  memory->destroy(r0);
  memory->destroy(gamma);
  memory->destroy(eta);
  memory->destroy(lam1inv);
  memory->destroy(lam4inv);
  memory->destroy(zizj);
  memory->destroy(mbigd);
  memory->destroy(dvrc);
  memory->destroy(big6w);
  memory->destroy(heta);
  memory->destroy(bigh);
  memory->destroy(bigw);
  memory->destroy(c0);
  memory->destroy(costheta);
  memory->destroy(bigb);
  memory->destroy(big2b);
  memory->destroy(bigc);

  GPU_EXTRA::check_flag(success, error, world);

  if (gpu_mode == GPU_FORCE)
    neighbor->add_request(this, NeighConst::REQ_FULL | NeighConst::REQ_GHOST);
  if (comm->get_comm_cutoff() < (2.0 * cutmax + neighbor->skin)) {
    comm->cutghostuser = 2.0 * cutmax + neighbor->skin;
    if (comm->me == 0)
      error->warning(FLERR, "Increasing communication cutoff to {:.8} for GPU pair style",
                     comm->cutghostuser);
  }
}

/* ----------------------------------------------------------------------
   init for one type pair i,j and corresponding j,i
------------------------------------------------------------------------- */

double PairVashishtaGPU::init_one(int i, int j)
{
  // create cutghost lazily on first use
  if (!gpu_allocated) allocate();

  if (!setflag[i][j])
    error->all(FLERR, Error::NOLASTLINE,
               "All pair coeffs are not set. Status:\n" + Info::get_pair_coeff_status(lmp));

  // every type pair uses the global maximum cutoff, symmetrically
  cutghost[i][j] = cutghost[j][i] = cutmax;
  return cutmax;
}
